
package com.jstarcraft.ai.jsat.math;

import java.io.Serializable;
import java.util.function.BinaryOperator;

/**
 *
 * This class provides a means of updating summary statistics as each new data
 * point is added. The data points are not stored, and values are updated with
 * an online algorithm. <br>
 * As such, this class has constant memory usage, regardless of how many values
 * are added. But the results may not be as numerically accurate, and can
 * degrade badly given specific data sequences.
 * 
 * @author Edward Raff
 */
public class OnLineStatistics implements Serializable, Cloneable, BinaryOperator<OnLineStatistics> {

    private static final long serialVersionUID = -4286295481362462983L;
    /**
     * The current mean
     */
    private double mean;
    /**
     * The current number of samples seen
     */
    private double n;

    // Intermediat value updated at each step, variance computed from it
    private double m2, m3, m4;

    private Double min, max;

    /**
     * Creates a new set of statistical counts with no information
     */
    public OnLineStatistics() {
        this(0, 0, 0, 0, 0);
    }

    /**
     * Creates a new set of statistical counts with these initial values, and can
     * then be updated in an online fashion
     * 
     * @param n        the total weight of all data points added. This value must be
     *                 non negative
     * @param mean     the starting mean. If <tt>n</tt> is zero, this value will be
     *                 ignored.
     * @param variance the starting variance. If <tt>n</tt> is zero, this value will
     *                 be ignored.
     * @param skew     the starting skewness. If <tt>n</tt> is zero, this value will
     *                 be ignored.
     * @param kurt     the starting kurtosis. If <tt>n</tt> is zero, this value will
     *                 be ignored.
     * @throws ArithmeticException if <tt>n</tt> is a negative number
     */
    public OnLineStatistics(double n, double mean, double variance, double skew, double kurt) {
        if (n < 0)
            throw new ArithmeticException("Can not have a negative set of weights");
        this.n = n;
        if (n != 0) {
            this.mean = mean;
            this.m2 = variance * (n - 1);
            this.m3 = Math.pow(m2, 3.0 / 2.0) * skew / Math.sqrt(n);
            this.m4 = (3 + kurt) * m2 * m2 / n;
        } else
            this.mean = m2 = m3 = m4 = 0;
        min = max = null;
    }

    private OnLineStatistics(double n, double mean, double m2, double m3, double m4, Double min, Double max) {
        this.n = n;
        this.mean = mean;
        this.m2 = m2;
        this.m3 = m3;
        this.m4 = m4;
        this.min = min;
        this.max = max;
    }

    /**
     * Copy Constructor
     * 
     * @param other the version to make a copy of
     */
    public OnLineStatistics(OnLineStatistics other) {
        this(other.n, other.mean, other.m2, other.m3, other.m4, other.min, other.max);
    }

    /**
     * Adds a data sample with unit weight to the counts.
     * 
     * @param x the data value to add
     */
    public void add(double x) {
        add(x, 1.0);
    }

    /**
     * Adds a data sample the the counts with the provided weight of influence.
     * 
     * @param x      the data value to add
     * @param weight the weight to give the value
     * @throws ArithmeticException if a negative weight is given
     */
    public void add(double x, double weight) {
        // See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

        if (weight < 0)
            throw new ArithmeticException("Can not add a negative weight");
        else if (weight == 0)
            return;

        double n1 = n;
        n += weight;
        double delta = x - mean;
        double delta_n = delta * weight / n;
        double delta_n2 = delta_n * delta_n;
        double term1 = delta * delta_n * n1;

        mean += delta_n;
        m4 += term1 * delta_n2 * (n * n - 3 * n + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3;
        m3 += term1 * delta_n * (n - 2) - 3 * delta_n * m2;
        m2 += weight * delta * (x - mean);

        if (min == null)
            min = max = x;
        else {
            min = Math.min(min, x);
            max = Math.max(max, x);
        }
    }

    /**
     * Effectively removes a sample with the given value and weight from the total.
     * Removing values that have not been added may yield results that have no
     * meaning <br>
     * <br>
     * NOTE: {@link #getSkewness() } and {@link #getKurtosis() } are not currently
     * updated correctly
     * 
     * @param x      the value of the sample
     * @param weight the weight of the sample
     * @throws ArithmeticException if a negative weight is given
     */
    public void remove(double x, double weight) {
        if (weight < 0)
            throw new ArithmeticException("Can not remove a negative weight");
        else if (weight == 0)
            return;

        double n1 = n;
        n -= weight;
        double delta = x - mean;
        double delta_n = delta * weight / n;
        double delta_n2 = delta_n * delta_n;
        double term1 = delta * delta_n * n1;

        mean -= delta_n;

        m2 -= weight * delta * (x - mean);
        // TODO m3 and m4 arent getting updated correctly
        m3 -= term1 * delta_n * (n - 2 + weight) - 3 * delta_n * m2;
        m4 -= term1 * delta_n2 * (n * n - 3 * n + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3;
    }

    /**
     * Computes a new set of statistics that is the equivalent of having removed all
     * observations in {@code B} from {@code A}. <br>
     * NOTE: removing statistics is not as numerically stable. The values of the 3rd
     * and 4th moments {@link #getSkewness() } and {@link #getKurtosis() } will be
     * inaccurate for many inputs. The {@link #getMin() min} and {@link #getMax()
     * max} can not be determined in this setting, and will not be altered.
     * 
     * @param A the first set of statistics, which must have a larger value for
     *          {@link #getSumOfWeights() } than {@code B}
     * @param B the set of statistics to remove from {@code A}.
     * @return a new set of statistics that is the removal of {@code B} from
     *         {@code A}
     */
    public static OnLineStatistics remove(OnLineStatistics A, OnLineStatistics B) {
        OnLineStatistics toRet = A.clone();
        toRet.remove(B);
        return toRet;
    }

    /**
     * Removes from this set of statistics the observations that where collected in
     * {@code B}.<br>
     * NOTE: removing statistics is not as numerically stable. The values of the 3rd
     * and 4th moments {@link #getSkewness() } and {@link #getKurtosis() } will be
     * inaccurate for many inputs. The {@link #getMin() min} and {@link #getMax()
     * max} can not be determined in this setting, and will not be altered.
     * 
     * @param B the set of statistics to remove
     */
    public void remove(OnLineStatistics B) {
        final OnLineStatistics A = this;
        // XXX double compare.
        if (A.n == B.n) {
            n = mean = m2 = m3 = m4 = 0;
            min = max = null;
            return;
        } else if (B.n == 0)
            return;// removed nothing!
        else if (A.n < B.n)
            throw new ArithmeticException("Can not have negative samples");

        double nX = A.n - B.n;
        double nXsqrd = nX * nX;
        double nAnB = B.n * A.n;
        double AnSqrd = A.n * A.n;
        double BnSqrd = B.n * B.n;

        double delta = B.mean - A.mean;
        double deltaSqrd = delta * delta;
        double deltaCbd = deltaSqrd * delta;
        double deltaQad = deltaSqrd * deltaSqrd;
        double newMean = (A.n * A.mean - B.n * B.mean) / (A.n - B.n);
        double newM2 = A.m2 - B.m2 - deltaSqrd / nX * nAnB;
        double newM3 = A.m3 - B.m3 - deltaCbd * nAnB * (A.n - B.n) / nXsqrd - 3 * delta * (A.n * B.m2 - B.n * A.m2) / nX;
        double newM4 = A.m4 - B.m4 - deltaQad * (nAnB * (AnSqrd - nAnB + BnSqrd) / (nXsqrd * nX)) - 6 * deltaSqrd * (AnSqrd * B.m2 - BnSqrd * A.m2) / nXsqrd - 4 * delta * (A.n * B.m3 - B.n * A.m3) / nX;

        this.n = nX;
        this.mean = newMean;
        this.m2 = newM2;
        this.m3 = newM3;
        this.m4 = newM4;
    }

    /**
     * Computes a new set of counts that is the sum of the counts from the given
     * distributions. <br>
     * <br>
     * NOTE: Adding two statistics is not as numerically stable. If A and B have
     * values of similar size and scale, the values of the 3rd and 4th moments
     * {@link #getSkewness() } and {@link #getKurtosis() } will suffer from
     * catastrophic cancellations, and may not be as accurate.
     * 
     * @param A the first set of statistics
     * @param B the second set of statistics
     * @return a new set of statistics that is the addition of the two.
     */
    public static OnLineStatistics add(OnLineStatistics A, OnLineStatistics B) {
        OnLineStatistics toRet = A.clone();
        toRet.add(B);
        return toRet;
    }

    /**
     * Adds to the current statistics all the samples that were collected in
     * {@code B}. <br>
     * NOTE: Adding two statistics is not as numerically stable. If A and B have
     * values of similar size and scale, the values of the 3rd and 4th moments
     * {@link #getSkewness() } and {@link #getKurtosis() } will suffer from
     * catastrophic cancellations, and may not be as accurate.
     * 
     * @param B the set of statistics to add to this set
     */
    public void add(OnLineStatistics B) {
        final OnLineStatistics A = this;
        // XXX double compare.
        if (A.n == B.n && B.n == 0)
            return;// nothing to do!
        else if (B.n == 0)
            return;// still nothing!
        else if (A.n == 0) {
            this.n = B.n;
            this.mean = B.mean;
            this.m2 = B.m2;
            this.m3 = B.m3;
            this.m4 = B.m4;
            this.min = B.min;
            this.max = B.max;
            return;
        }

        double nX = B.n + A.n;
        double nXsqrd = nX * nX;
        double nAnB = B.n * A.n;
        double AnSqrd = A.n * A.n;
        double BnSqrd = B.n * B.n;

        double delta = B.mean - A.mean;
        double deltaSqrd = delta * delta;
        double deltaCbd = deltaSqrd * delta;
        double deltaQad = deltaSqrd * deltaSqrd;
        double newMean = (A.n * A.mean + B.n * B.mean) / (A.n + B.n);
        double newM2 = A.m2 + B.m2 + deltaSqrd / nX * nAnB;
        double newM3 = A.m3 + B.m3 + deltaCbd * nAnB * (A.n - B.n) / nXsqrd + 3 * delta * (A.n * B.m2 - B.n * A.m2) / nX;
        double newM4 = A.m4 + B.m4 + deltaQad * (nAnB * (AnSqrd - nAnB + BnSqrd) / (nXsqrd * nX)) + 6 * deltaSqrd * (AnSqrd * B.m2 + BnSqrd * A.m2) / nXsqrd + 4 * delta * (A.n * B.m3 - B.n * A.m3) / nX;

        this.n = nX;
        this.mean = newMean;
        this.m2 = newM2;
        this.m3 = newM3;
        this.m4 = newM4;
        this.min = Math.min(A.min, B.min);
        this.max = Math.max(A.max, B.max);
    }

    @Override
    public OnLineStatistics clone() {
        return new OnLineStatistics(n, mean, m2, m3, m4, min, max);
    }

    /**
     * Returns the sum of the weights for all data points added to the statistics.
     * If all weights were 1, then this value is the number of data points added.
     * 
     * @return the sum of weights for every point currently contained in the
     *         statistics.
     */
    public double getSumOfWeights() {
        return n;
    }

    public double getMean() {
        return mean;
    }

    /**
     * Computes the population variance
     * 
     * @return the variance of the data seen
     */
    public double getVarance() {
        return m2 / (n + 1e-15);// USED to be unbiased est, but dosn't work for weighted data when the weights
                                // may be <= 1. So use biased.
    }

    public double getStandardDeviation() {
        return Math.sqrt(getVarance());
    }

    public double getSkewness() {
        return Math.sqrt(n) * m3 / Math.pow(m2, 3.0 / 2.0);
    }

    public double getKurtosis() {
        return (n * m4) / (m2 * m2) - 3;
    }

    public double getMin() {
        return min;
    }

    public double getMax() {
        return max;
    }

    @Override
    public OnLineStatistics apply(OnLineStatistics t, OnLineStatistics u) {
        if (t == null)
            return u;
        else if (u == null)
            return t;
        // else, boh are non-null
        return OnLineStatistics.add(t, u);
    }

}
